package com.kklt.kson.util;

import java.io.File;
import java.io.IOException;
import java.net.JarURLConnection;
import java.net.URL;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Set;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;

/**
 * 类加载器工具类
 * @author lishouyu
 */
public final class ClassLoaderUtils {

  /**
   * 获取类加载器
   * @return
   */
  public static  ClassLoader getClassLoader(){
    return Thread.currentThread().getContextClassLoader();
  }

  /**
   * 根据全类名加载类
   * @param className 全类名
   * @param initialize 是否执行类的静态代码块
   * @return
   */
  public static Class<?> getClassByName(String className,boolean initialize){
    Class<?> cls = null;
    try {
     cls = Class.forName(className,initialize,getClassLoader());
    } catch (ClassNotFoundException e) {
      e.printStackTrace();
    }
    return cls;
  }

  /**
   *  加载包下的所有类
   * @param packageName 加载包下的所有类
   * @return
   */
  public static Set<Class<?>> getClassSet(String packageName){
    Set<Class<?>> classSet = new HashSet<>();
    try {
      Enumeration<URL> urls = getClassLoader().getResources(packageName.replaceAll("\\.", "/"));
      while (urls.hasMoreElements()){
        URL url = urls.nextElement();
        if (url !=null){
          String protocol = url.getProtocol();
          if ("file".equals(protocol)){
            addClass(url.getPath(),classSet,packageName);
          }else if ("jar".equals(protocol)){
            JarURLConnection jarURLConnection = (JarURLConnection)url.openConnection();
            if (jarURLConnection != null){
              JarFile jarFile = jarURLConnection.getJarFile();
              Enumeration<JarEntry> entries = jarFile.entries();
              while (entries.hasMoreElements()){
                JarEntry jarEntry = entries.nextElement();
                String name = jarEntry.getName();
                if (name.endsWith(".class")){
                  String className = name.substring(0, name.lastIndexOf(".")).replaceAll("/", ".");
                  Class<?> classByName = getClassByName(className, true);
                  classSet.add(classByName);
                }
              }
            }
          }
        }
      }
    } catch (IOException e) {
      e.printStackTrace();
    }
    return classSet;
  }


  /**
   * 递归加载包下的所有类
   * @param packagePath 包名
   */
  private static void addClass(String packagePath,Set<Class<?>> classSet , String packageName){
    File[] classFileAndDir = getClassFile(packagePath);
    if (classFileAndDir!=null){
      for (File file : classFileAndDir) {
        if (file.isDirectory()){
          //1. 对路径下的目录进行递归
          addClass(file.getPath(),classSet,packageName + "." + file.getName());
        }else{
          //2. 加载类
          //2.1 获取文件名 *.class
          String fileName = file.getName();
          //2.2 获取不带扩展名的文件名
          String classNameWithOutPackage = fileName.substring(0, fileName.lastIndexOf("."));
          //2.3 拼接得到类的带包名称的全类名
          String className = packageName + "." + classNameWithOutPackage;
          //2.4 反射加载类
          Class<?> classByName = getClassByName(className, true);
          //2.5 添加入
          classSet.add(classByName);
        }
      }
    }
  }

  /**
   * 加载该路径下所有的class文件和目录
   * @param filePath 文件路径
   */
  private static File[] getClassFile(String filePath) {
    return new File(filePath).listFiles(file -> (file.isFile() && file.getName().endsWith(".class")) || file.isDirectory());
  }
}
